// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <optional>

#include "mlx/array.h"
#include "mlx/device.h"
#include "mlx/stream.h"
#include "mlx/utils.h"

namespace mlx::core {

/**
 * \defgroup ops Core array operations
 * @{
 */

/**
 * A 1D array of numbers starting at `start` (optional),
 * stopping at `stop`, stepping by `step` (optional).
 *
 * The `double` overloads come with and without an explicit output `dtype`;
 * the `int` overloads further down take no `dtype` argument.
 */
array arange(
    double start,
    double stop,
    double step,
    Dtype dtype,
    StreamOrDevice s = {});
array arange(double start, double stop, double step, StreamOrDevice s = {});
array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double start, double stop, StreamOrDevice s = {});
array arange(double stop, Dtype dtype, StreamOrDevice s = {});
array arange(double stop, StreamOrDevice s = {});

// Integer-bound overloads; none of them accepts a `dtype`.
array arange(int start, int stop, int step, StreamOrDevice s = {});
array arange(int start, int stop, StreamOrDevice s = {});
array arange(int stop, StreamOrDevice s = {});

/**
 * A 1D array of `num` evenly spaced numbers in the range `[start, stop]`.
 *
 * @param start First bound of the range.
 * @param stop Second bound of the range.
 * @param num Number of values to generate; defaults to 50.
 * @param dtype Data type of the result; defaults to `float32`.
 * @param s Stream or device selector; default-constructed when omitted.
 */
array linspace(
    double start,
    double stop,
    int num = 50,
    Dtype dtype = float32,
    StreamOrDevice s = {});

/** Convert an array to the given data type. */
array astype(array a, Dtype dtype, StreamOrDevice s = {});

/**
 * Create a view of an array with the given shape and strides.
 *
 * @param a Array to view.
 * @param shape Shape of the resulting view.
 * @param strides Strides of the resulting view.
 * @param offset Starting offset into `a`.
 *
 * NOTE(review): the unit of `strides`/`offset` (elements vs. bytes) and any
 * length requirement between `shape` and `strides` are not stated in this
 * header — confirm against the implementation.
 */
array as_strided(
    array a,
    std::vector<int> shape,
    std::vector<size_t> strides,
    size_t offset,
    StreamOrDevice s = {});

/** Copy another array. */
array copy(array a, StreamOrDevice s = {});

/** Create an array of the given shape filled with the given value(s). */
array full(
    std::vector<int> shape,
    array vals,
    Dtype dtype,
    StreamOrDevice s = {});
array full(std::vector<int> shape, array vals, StreamOrDevice s = {});
/**
 * Overload for a non-array `val`: the value is first turned into an `array`
 * of type `dtype`, then handed to the array-valued overload.
 */
template <typename T>
array full(std::vector<int> shape, T val, Dtype dtype, StreamOrDevice s = {}) {
  array fill_value(val, dtype);
  return full(std::move(shape), std::move(fill_value), to_stream(s));
}
/**
 * Overload for a non-array `val` without an explicit dtype: the value is
 * wrapped with `array(val)` and handed to the array-valued overload.
 */
template <typename T>
array full(std::vector<int> shape, T val, StreamOrDevice s = {}) {
  array fill_value(val);
  return full(std::move(shape), std::move(fill_value), to_stream(s));
}

/** Create an array of the given shape filled with zeros. */
array zeros(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
/** Zeros of the given shape with the dtype fixed to `float32`. */
inline array zeros(const std::vector<int>& shape, StreamOrDevice s = {}) {
  const auto default_dtype = float32;
  return zeros(shape, default_dtype, s);
}
/** NOTE(review): presumably zeros matching `a` in shape/dtype — confirm. */
array zeros_like(const array& a, StreamOrDevice s = {});

/** Create an array of the given shape filled with ones. */
array ones(const std::vector<int>& shape, Dtype dtype, StreamOrDevice s = {});
/** Ones of the given shape with the dtype fixed to `float32`. */
inline array ones(const std::vector<int>& shape, StreamOrDevice s = {}) {
  const auto default_dtype = float32;
  return ones(shape, default_dtype, s);
}
/** NOTE(review): presumably ones matching `a` in shape/dtype — confirm. */
array ones_like(const array& a, StreamOrDevice s = {});

/**
 * Create an array of shape (n,m) with ones on the diagonal `k` and zeros
 * everywhere else.
 *
 * The inline overloads below fill in the missing arguments: `m = n`, `k = 0`
 * and `dtype = float32`.
 */
array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {});
inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) {
  const int m = n;
  const int k = 0;
  return eye(n, m, k, dtype, s);
}
inline array eye(int n, int m, StreamOrDevice s = {}) {
  const int k = 0;
  const auto default_dtype = float32;
  return eye(n, m, k, default_dtype, s);
}
inline array eye(int n, int m, int k, StreamOrDevice s = {}) {
  const auto default_dtype = float32;
  return eye(n, m, k, default_dtype, s);
}
inline array eye(int n, StreamOrDevice s = {}) {
  const int m = n;
  const int k = 0;
  const auto default_dtype = float32;
  return eye(n, m, k, default_dtype, s);
}

/** Create a square (n,n) matrix of zeros with ones on the major diagonal. */
array identity(int n, Dtype dtype, StreamOrDevice s = {});
/** Identity matrix with the dtype fixed to `float32`. */
inline array identity(int n, StreamOrDevice s = {}) {
  const auto default_dtype = float32;
  return identity(n, default_dtype, s);
}

/**
 * NOTE(review): undocumented in the original header. Presumably builds an
 * (n,m) triangular matrix relative to diagonal `k` (as in `numpy.tri`) —
 * confirm against the implementation.
 */
array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {});
/** Square variant: forwards `m = n` and `k = 0`. */
inline array tri(int n, Dtype type, StreamOrDevice s = {}) {
  const int m = n;
  const int k = 0;
  return tri(n, m, k, type, s);
}

/**
 * NOTE(review): undocumented in the original header. Presumably the lower /
 * upper triangular part of `x` relative to diagonal `k` (default 0), as in
 * `numpy.tril` / `numpy.triu` — confirm.
 */
array tril(array x, int k = 0, StreamOrDevice s = {});
array triu(array x, int k = 0, StreamOrDevice s = {});

/** Reshape an array to the given shape. */
array reshape(const array& a, std::vector<int> shape, StreamOrDevice s = {});

/**
 * Flatten the dimensions in the range `[start_axis, end_axis]`.
 *
 * `end_axis` defaults to -1.
 * NOTE(review): -1 presumably denotes the last axis — confirm.
 */
array flatten(
    const array& a,
    int start_axis,
    int end_axis = -1,
    StreamOrDevice s = {});

/** Flatten the array to 1D. */
array flatten(const array& a, StreamOrDevice s = {});

/** Drop the singleton dimensions found at the given axes. */
array squeeze(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});

/** Drop the singleton dimension at a single axis. */
inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) {
  // Wrap the lone axis in a one-element list and reuse the multi-axis form.
  const std::vector<int> axes{axis};
  return squeeze(a, axes, s);
}

/** Drop every singleton dimension. */
array squeeze(const array& a, StreamOrDevice s = {});

/** Add singleton dimensions at the given axes. */
array expand_dims(
    const array& a,
    const std::vector<int>& axes,
    StreamOrDevice s = {});

/** Add a singleton dimension at the given axis. */
array expand_dims(const array& a, int axis, StreamOrDevice s = {});

/**
 * Slice an array.
 *
 * @param a Array to slice.
 * @param start Start index per dimension.
 * @param stop Stop index per dimension.
 * @param strides Step per dimension.
 *
 * NOTE(review): whether `stop` is exclusive and how negative values are
 * treated is not visible in this header — confirm.
 */
array slice(
    const array& a,
    std::vector<int> start,
    std::vector<int> stop,
    std::vector<int> strides,
    StreamOrDevice s = {});

/** Slice an array with a stride of 1 in each dimension. */
array slice(
    const array& a,
    const std::vector<int>& start,
    const std::vector<int>& stop,
    StreamOrDevice s = {});

/**
 * Update a slice from the source array.
 *
 * @param src Source array.
 * @param update Values for the slice.
 * @param start Start index per dimension.
 * @param stop Stop index per dimension.
 * @param strides Step per dimension.
 */
array slice_update(
    const array& src,
    const array& update,
    std::vector<int> start,
    std::vector<int> stop,
    std::vector<int> strides,
    StreamOrDevice s = {});

/** Update a slice from the source array with stride 1 in each dimension. */
array slice_update(
    const array& src,
    const array& update,
    std::vector<int> start,
    std::vector<int> stop,
    StreamOrDevice s = {});

/**
 * Split an array into sub-arrays along a given axis.
 *
 * Two flavours are declared: splitting by a count (`num_splits`) or by a
 * list of `indices`. Each flavour also has an overload without `axis`.
 * NOTE(review): the axis used by the axis-less overloads is not visible in
 * this header — confirm against the implementation.
 */
std::vector<array>
split(const array& a, int num_splits, int axis, StreamOrDevice s = {});
std::vector<array> split(const array& a, int num_splits, StreamOrDevice s = {});
std::vector<array> split(
    const array& a,
    const std::vector<int>& indices,
    int axis,
    StreamOrDevice s = {});
std::vector<array>
split(const array& a, const std::vector<int>& indices, StreamOrDevice s = {});

/**
 * A vector of coordinate arrays from coordinate vectors.
 *
 * @param arrays Input coordinate vectors.
 * @param sparse Defaults to `false`.
 * @param indexing Indexing mode string; defaults to `"xy"`.
 *
 * NOTE(review): the accepted `indexing` values and the effect of `sparse`
 * are not visible here (presumably as in `numpy.meshgrid`) — confirm.
 */
std::vector<array> meshgrid(
    const std::vector<array>& arrays,
    bool sparse = false,
    std::string indexing = "xy",
    StreamOrDevice s = {});

/**
 * Clip (limit) the values in an array.
 *
 * @param a Array to clip.
 * @param a_min Optional lower bound; defaults to `std::nullopt`.
 * @param a_max Optional upper bound; defaults to `std::nullopt`.
 *
 * NOTE(review): behaviour when both bounds are omitted is not visible in
 * this header — confirm.
 */
array clip(
    const array& a,
    const std::optional<array>& a_min = std::nullopt,
    const std::optional<array>& a_max = std::nullopt,
    StreamOrDevice s = {});

/**
 * Concatenate arrays along a given axis.
 * NOTE(review): the behaviour of the overload without `axis` is not visible
 * in this header — confirm.
 */
array concatenate(
    const std::vector<array>& arrays,
    int axis,
    StreamOrDevice s = {});
array concatenate(const std::vector<array>& arrays, StreamOrDevice s = {});

/** Stack arrays along a new axis. */
array stack(const std::vector<array>& arrays, int axis, StreamOrDevice s = {});
array stack(const std::vector<array>& arrays, StreamOrDevice s = {});

/** Repeat an array along an axis. */
array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {});
array repeat(const array& arr, int repeats, StreamOrDevice s = {});

/**
 * NOTE(review): undocumented in the original header. Presumably tiles `arr`
 * `reps[i]` times along each dimension (as in `numpy.tile`) — confirm.
 */
array tile(const array& arr, std::vector<int> reps, StreamOrDevice s = {});

/** Permute the dimensions of an array according to the given axes. */
array transpose(const array& a, std::vector<int> axes, StreamOrDevice s = {});
/** Same permutation, with the axes given as a braced list. */
inline array transpose(
    const array& a,
    std::initializer_list<int> axes,
    StreamOrDevice s = {}) {
  // Materialise the list as a vector and defer to the overload above.
  std::vector<int> axes_vec(axes.begin(), axes.end());
  return transpose(a, std::move(axes_vec), s);
}

/** Swap two axes (`axis1` and `axis2`) of an array. */
array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {});

/** Move an axis of an array from position `source` to `destination`. */
array moveaxis(
    const array& a,
    int source,
    int destination,
    StreamOrDevice s = {});

/**
 * Pad an array with a constant value.
 *
 * @param a Array to pad.
 * @param axes Axes to pad.
 * @param low_pad_size Padding sizes on the low side.
 * @param high_pad_size Padding sizes on the high side.
 * @param pad_value Constant used for padding; defaults to `array(0)`.
 *
 * NOTE(review): presumably `low_pad_size[i]`/`high_pad_size[i]` apply to
 * `axes[i]` — confirm against the implementation.
 */
array pad(
    const array& a,
    const std::vector<int>& axes,
    const std::vector<int>& low_pad_size,
    const std::vector<int>& high_pad_size,
    const array& pad_value = array(0),
    StreamOrDevice s = {});

/**
 * Pad an array with a constant value along all axes.
 *
 * `pad_width` can be given as a list of pairs, as a single pair, or as a
 * single integer. `pad_value` defaults to `array(0)` in every overload.
 * NOTE(review): presumably each pair is (low, high) — confirm.
 */
array pad(
    const array& a,
    const std::vector<std::pair<int, int>>& pad_width,
    const array& pad_value = array(0),
    StreamOrDevice s = {});
array pad(
    const array& a,
    const std::pair<int, int>& pad_width,
    const array& pad_value = array(0),
    StreamOrDevice s = {});
array pad(
    const array& a,
    int pad_width,
    const array& pad_value = array(0),
    StreamOrDevice s = {});

/** Permute the dimensions of an array in reverse order. */
array transpose(const array& a, StreamOrDevice s = {});

/** Broadcast an array to a given shape. */
array broadcast_to(
    const array& a,
    const std::vector<int>& shape,
    StreamOrDevice s = {});

/** Broadcast a vector of arrays against one another. */
std::vector<array> broadcast_arrays(
    const std::vector<array>& inputs,
    StreamOrDevice s = {});

/** Element-wise (a == b); the result is a bool array. */
array equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator==(const array& a, const array& b) {
  return equal(a, b);
}
/** `a == b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator==(T a, const array& b) {
  const array lhs(a);
  return equal(lhs, b);
}
/** `a == b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator==(const array& a, T b) {
  const array rhs(b);
  return equal(a, rhs);
}

/** Element-wise (a != b); the result is a bool array. */
array not_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator!=(const array& a, const array& b) {
  return not_equal(a, b);
}
/** `a != b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator!=(T a, const array& b) {
  const array lhs(a);
  return not_equal(lhs, b);
}
/** `a != b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator!=(const array& a, T b) {
  const array rhs(b);
  return not_equal(a, rhs);
}

/** Element-wise (a > b); the result is a bool array. */
array greater(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>(const array& a, const array& b) {
  return greater(a, b);
}
/** `a > b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator>(T a, const array& b) {
  const array lhs(a);
  return greater(lhs, b);
}
/** `a > b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator>(const array& a, T b) {
  const array rhs(b);
  return greater(a, rhs);
}

/** Element-wise (a >= b); the result is a bool array. */
array greater_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator>=(const array& a, const array& b) {
  return greater_equal(a, b);
}
/** `a >= b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator>=(T a, const array& b) {
  const array lhs(a);
  return greater_equal(lhs, b);
}
/** `a >= b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator>=(const array& a, T b) {
  const array rhs(b);
  return greater_equal(a, rhs);
}

/** Element-wise (a < b); the result is a bool array. */
array less(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<(const array& a, const array& b) {
  return less(a, b);
}
/** `a < b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator<(T a, const array& b) {
  const array lhs(a);
  return less(lhs, b);
}
/** `a < b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator<(const array& a, T b) {
  const array rhs(b);
  return less(a, rhs);
}

/** Element-wise (a <= b); the result is a bool array. */
array less_equal(const array& a, const array& b, StreamOrDevice s = {});
inline array operator<=(const array& a, const array& b) {
  return less_equal(a, b);
}
/** `a <= b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator<=(T a, const array& b) {
  const array lhs(a);
  return less_equal(lhs, b);
}
/** `a <= b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator<=(const array& a, T b) {
  const array rhs(b);
  return less_equal(a, rhs);
}

/** True if two arrays have the same shape and the same elements. */
array array_equal(
    const array& a,
    const array& b,
    bool equal_nan,
    StreamOrDevice s = {});
/** Same check with `equal_nan` fixed to `false`. */
inline array
array_equal(const array& a, const array& b, StreamOrDevice s = {}) {
  const bool equal_nan = false;
  return array_equal(a, b, equal_nan, s);
}

// NOTE(review): the four predicates below are undocumented in the original
// header. Going by their names they presumably test each element for NaN,
// +/-infinity, +infinity and -infinity respectively — confirm against the
// implementation.
array isnan(const array& a, StreamOrDevice s = {});

array isinf(const array& a, StreamOrDevice s = {});

array isposinf(const array& a, StreamOrDevice s = {});

array isneginf(const array& a, StreamOrDevice s = {});

/**
 * Select from `x` or `y` depending on `condition`.
 * NOTE(review): presumably `x` is taken where `condition` is true and `y`
 * otherwise — confirm.
 */
array where(
    const array& condition,
    const array& x,
    const array& y,
    StreamOrDevice s = {});

/** True if every element of the array is true (or non-zero). */
array all(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array all(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return all(a, keepdims, to_stream(s));
}

/**
 * True if the two arrays are equal within the specified tolerance.
 *
 * @param rtol Relative tolerance; defaults to 1e-5.
 * @param atol Absolute tolerance; defaults to 1e-8.
 * @param equal_nan Defaults to `false`.
 *   NOTE(review): presumably treats NaNs as equal when true — confirm.
 */
array allclose(
    const array& a,
    const array& b,
    double rtol = 1e-5,
    double atol = 1e-8,
    bool equal_nan = false,
    StreamOrDevice s = {});

/**
 * Returns a boolean array where two arrays are element-wise equal within the
 * specified tolerance.
 *
 * @param rtol Relative tolerance; defaults to 1e-5.
 * @param atol Absolute tolerance; defaults to 1e-8.
 * @param equal_nan Defaults to `false`.
 *   NOTE(review): presumably treats NaNs as equal when true — confirm.
 */
array isclose(
    const array& a,
    const array& b,
    double rtol = 1e-5,
    double atol = 1e-8,
    bool equal_nan = false,
    StreamOrDevice s = {});

/**
 *  Reduces the input along the given axes. An output value is true
 *  if all the corresponding inputs are true.
 *  `keepdims` defaults to `false`.
 **/
array all(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 *  Reduces the input along the given axis. An output value is true
 *  if all the corresponding inputs are true.
 *  `keepdims` defaults to `false`.
 **/
array all(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** True if at least one element of the array is true (or non-zero). */
array any(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array any(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return any(a, keepdims, to_stream(s));
}

/**
 *  Reduction along the given axes: an output value is true if any of the
 *  corresponding inputs are true.
 **/
array any(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 *  Reduction along a single axis: an output value is true if any of the
 *  corresponding inputs are true.
 **/
array any(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Sum of the elements of an array. */
array sum(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array sum(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return sum(a, keepdims, to_stream(s));
}

/** Sum of the elements of an array along the given axes. */
array sum(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Sum of the elements of an array along the given axis. */
array sum(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Mean of the elements of an array. */
array mean(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array mean(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return mean(a, keepdims, to_stream(s));
}

/** Mean of the elements of an array along the given axes. */
array mean(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Mean of the elements of an array along the given axis. */
array mean(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Variance of the elements of an array; `ddof` defaults to 0. */
array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
/** Same reduction with `keepdims = false` and `ddof = 0`. */
inline array var(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  const int ddof = 0;
  return var(a, keepdims, ddof, to_stream(s));
}

/** Variance of the elements of an array along the given axes. */
array var(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    int ddof = 0,
    StreamOrDevice s = {});

/** Variance of the elements of an array along the given axis. */
array var(
    const array& a,
    int axis,
    bool keepdims = false,
    int ddof = 0,
    StreamOrDevice s = {});

/** Standard deviation of the elements of an array; `ddof` defaults to 0. */
array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {});
/** Same reduction with `keepdims = false` and `ddof = 0`. */
inline array std(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  const int ddof = 0;
  return std(a, keepdims, ddof, to_stream(s));
}

/** Standard deviation of the elements of an array along the given axes. */
array std(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    int ddof = 0,
    StreamOrDevice s = {});

/** Standard deviation of the elements of an array along the given axis. */
array std(
    const array& a,
    int axis,
    bool keepdims = false,
    int ddof = 0,
    StreamOrDevice s = {});

/** Product of all the elements of an array. */
array prod(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array prod(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return prod(a, keepdims, to_stream(s));
}

/** Product of the elements of an array along the given axes. */
array prod(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Product of the elements of an array along the given axis. */
array prod(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Maximum over all the elements of an array. */
array max(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array max(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return max(a, keepdims, to_stream(s));
}

/** Maximum of the elements of an array along the given axes. */
array max(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Maximum of the elements of an array along the given axis. */
array max(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Minimum over all the elements of an array. */
array min(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array min(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return min(a, keepdims, to_stream(s));
}

/** Minimum of the elements of an array along the given axes. */
array min(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Minimum of the elements of an array along the given axis. */
array min(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Index of the minimum value in the array. */
array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same lookup with `keepdims` fixed to `false`. */
inline array argmin(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return argmin(a, keepdims, s);
}

/** Indices of the minimum values along a given axis. */
array argmin(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Index of the maximum value in the array. */
array argmax(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same lookup with `keepdims` fixed to `false`. */
inline array argmax(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return argmax(a, keepdims, s);
}

/** Indices of the maximum values along a given axis. */
array argmax(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/**
 * Returns a sorted copy of the flattened array.
 * NOTE(review): sort order (presumably ascending) is not stated here —
 * confirm.
 */
array sort(const array& a, StreamOrDevice s = {});

/** Returns a sorted copy of the array along a given axis. */
array sort(const array& a, int axis, StreamOrDevice s = {});

/** Returns indices that sort the flattened array. */
array argsort(const array& a, StreamOrDevice s = {});

/** Returns indices that sort the array along a given axis. */
array argsort(const array& a, int axis, StreamOrDevice s = {});

/**
 * Returns a partitioned copy of the flattened array
 * such that the smaller `kth` elements are first.
 **/
array partition(const array& a, int kth, StreamOrDevice s = {});

/**
 * Returns a partitioned copy of the array along a given axis
 * such that the smaller `kth` elements are first.
 **/
array partition(const array& a, int kth, int axis, StreamOrDevice s = {});

/**
 * Returns indices that partition the flattened array
 * such that the smaller `kth` elements are first.
 **/
array argpartition(const array& a, int kth, StreamOrDevice s = {});

/**
 * Returns indices that partition the array along a given axis
 * such that the smaller `kth` elements are first.
 **/
array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {});

/**
 * Returns the top `k` elements of the flattened array.
 * NOTE(review): presumably the `k` largest values; whether the result is
 * ordered is not visible here — confirm.
 */
array topk(const array& a, int k, StreamOrDevice s = {});

/** Returns the top `k` elements of the array along a given axis. */
array topk(const array& a, int k, int axis, StreamOrDevice s = {});

/** Logsumexp over all the elements of an array. */
array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {});
/** Same reduction with `keepdims` fixed to `false`. */
inline array logsumexp(const array& a, StreamOrDevice s = {}) {
  const bool keepdims = false;
  return logsumexp(a, keepdims, to_stream(s));
}

/** Logsumexp of the elements of an array along the given axes. */
array logsumexp(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Logsumexp of the elements of an array along the given axis. */
array logsumexp(
    const array& a,
    int axis,
    bool keepdims = false,
    StreamOrDevice s = {});

/** Absolute value of elements in an array. */
array abs(const array& a, StreamOrDevice s = {});

/** Negate an array. */
array negative(const array& a, StreamOrDevice s = {});
// Unary minus; defined out of line (presumably in terms of `negative`).
array operator-(const array& a);

/** The sign of the elements in an array. */
array sign(const array& a, StreamOrDevice s = {});

/** Logical not of an array. */
array logical_not(const array& a, StreamOrDevice s = {});

/** Logical and of two arrays. */
array logical_and(const array& a, const array& b, StreamOrDevice s = {});
// Defined out of line (presumably in terms of `logical_and`).
array operator&&(const array& a, const array& b);

/** Logical or of two arrays. */
array logical_or(const array& a, const array& b, StreamOrDevice s = {});
// Defined out of line (presumably in terms of `logical_or`).
array operator||(const array& a, const array& b);

/** The reciprocal (1/x) of the elements in an array. */
array reciprocal(const array& a, StreamOrDevice s = {});

/** Sum of two arrays. */
array add(const array& a, const array& b, StreamOrDevice s = {});
array operator+(const array& a, const array& b);
/** `a + b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator+(T a, const array& b) {
  const array lhs(a);
  return add(lhs, b);
}
/** `a + b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator+(const array& a, T b) {
  const array rhs(b);
  return add(a, rhs);
}

/** Difference of two arrays. */
array subtract(const array& a, const array& b, StreamOrDevice s = {});
array operator-(const array& a, const array& b);
/** `a - b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator-(T a, const array& b) {
  const array lhs(a);
  return subtract(lhs, b);
}
/** `a - b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator-(const array& a, T b) {
  const array rhs(b);
  return subtract(a, rhs);
}

/** Product of two arrays. */
array multiply(const array& a, const array& b, StreamOrDevice s = {});
array operator*(const array& a, const array& b);
/** `a * b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator*(T a, const array& b) {
  const array lhs(a);
  return multiply(lhs, b);
}
/** `a * b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator*(const array& a, T b) {
  const array rhs(b);
  return multiply(a, rhs);
}

/**
 * Divide two arrays.
 *
 * Unlike the other arithmetic operators in this header, the mixed
 * scalar/array forms of `operator/` are not templates: the scalar operand
 * is taken as `double`.
 */
array divide(const array& a, const array& b, StreamOrDevice s = {});
array operator/(const array& a, const array& b);
array operator/(double a, const array& b);
array operator/(const array& a, double b);

/**
 * Compute the element-wise quotient and remainder.
 * NOTE(review): presumably the returned vector holds {quotient, remainder}
 * in that order — confirm.
 */
std::vector<array>
divmod(const array& a, const array& b, StreamOrDevice s = {});

/** Compute integer division. Equivalent to doing floor(a / b). */
array floor_divide(const array& a, const array& b, StreamOrDevice s = {});

/** Element-wise remainder of division. */
array remainder(const array& a, const array& b, StreamOrDevice s = {});
array operator%(const array& a, const array& b);
/** `a % b` with a non-array left operand, converted via `array(a)`. */
template <typename T>
array operator%(T a, const array& b) {
  const array lhs(a);
  return remainder(lhs, b);
}
/** `a % b` with a non-array right operand, converted via `array(b)`. */
template <typename T>
array operator%(const array& a, T b) {
  const array rhs(b);
  return remainder(a, rhs);
}

/** Element-wise maximum between two arrays. */
array maximum(const array& a, const array& b, StreamOrDevice s = {});

/** Element-wise minimum between two arrays. */
array minimum(const array& a, const array& b, StreamOrDevice s = {});

/** Floor the elements of an array. **/
array floor(const array& a, StreamOrDevice s = {});

/** Ceil the elements of an array. **/
array ceil(const array& a, StreamOrDevice s = {});

/** Square the elements of an array. */
array square(const array& a, StreamOrDevice s = {});

/** Exponential of the elements of an array. */
array exp(const array& a, StreamOrDevice s = {});

/** Sine of the elements of an array. */
array sin(const array& a, StreamOrDevice s = {});

/** Cosine of the elements of an array. */
array cos(const array& a, StreamOrDevice s = {});

/** Tangent of the elements of an array. */
array tan(const array& a, StreamOrDevice s = {});

/** Arc sine of the elements of an array. */
array arcsin(const array& a, StreamOrDevice s = {});

/** Arc cosine of the elements of an array. */
array arccos(const array& a, StreamOrDevice s = {});

/** Arc tangent of the elements of an array. */
array arctan(const array& a, StreamOrDevice s = {});

/**
 * Inverse tangent of the ratio of two arrays.
 * NOTE(review): presumably the ratio is `a / b` — confirm the argument
 * order.
 */
array arctan2(const array& a, const array& b, StreamOrDevice s = {});

/** Hyperbolic sine of the elements of an array. */
array sinh(const array& a, StreamOrDevice s = {});

/** Hyperbolic cosine of the elements of an array. */
array cosh(const array& a, StreamOrDevice s = {});

/** Hyperbolic tangent of the elements of an array. */
array tanh(const array& a, StreamOrDevice s = {});

/** Inverse hyperbolic sine of the elements of an array. */
array arcsinh(const array& a, StreamOrDevice s = {});

/** Inverse hyperbolic cosine of the elements of an array. */
array arccosh(const array& a, StreamOrDevice s = {});

/** Inverse hyperbolic tangent of the elements of an array. */
array arctanh(const array& a, StreamOrDevice s = {});

/** Convert the elements of an array from radians to degrees. **/
array degrees(const array& a, StreamOrDevice s = {});

/** Convert the elements of an array from degrees to radians. **/
array radians(const array& a, StreamOrDevice s = {});

/** Natural logarithm of the elements of an array. */
array log(const array& a, StreamOrDevice s = {});

/** Log base 2 of the elements of an array. */
array log2(const array& a, StreamOrDevice s = {});

/** Log base 10 of the elements of an array. */
array log10(const array& a, StreamOrDevice s = {});

/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */
array log1p(const array& a, StreamOrDevice s = {});

/** Log-add-exp of the elements of two arrays: `log(exp(a) + exp(b))`. */
array logaddexp(const array& a, const array& b, StreamOrDevice s = {});

/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-a))`. */
array sigmoid(const array& a, StreamOrDevice s = {});

/** Computes the error function of the elements of an array. */
array erf(const array& a, StreamOrDevice s = {});

/** Computes the inverse error function of the elements of an array. */
array erfinv(const array& a, StreamOrDevice s = {});

/** Computes the expm1 function of the elements of an array. */
array expm1(const array& a, StreamOrDevice s = {});

/** Stop the flow of gradients. */
array stop_gradient(const array& a, StreamOrDevice s = {});

/** Round the elements of an array to the given number of `decimals`. */
array round(const array& a, int decimals, StreamOrDevice s = {});
/** Rounding with `decimals` fixed to 0. */
inline array round(const array& a, StreamOrDevice s = {}) {
  const int decimals = 0;
  return round(a, decimals, s);
}

/**
 * Matrix-matrix multiplication of `a` and `b`.
 * NOTE(review): handling of 1-D or batched (>2-D) inputs is not visible in
 * this header — confirm.
 */
array matmul(const array& a, const array& b, StreamOrDevice s = {});

/** Gather entries of an array given indices and slice sizes. */
array gather(
    const array& a,
    const std::vector<array>& indices,
    const std::vector<int>& axes,
    const std::vector<int>& slice_sizes,
    StreamOrDevice s = {});
/** Single index array / single axis form of the overload above. */
inline array gather(
    const array& a,
    const array& indices,
    int axis,
    const std::vector<int>& slice_sizes,
    StreamOrDevice s = {}) {
  // Promote the lone index array and axis to one-element lists.
  const std::vector<array> index_list{indices};
  const std::vector<int> axes{axis};
  return gather(a, index_list, axes, slice_sizes, s);
}

/** Take array slices at the given indices of the specified axis. */
array take(
    const array& a,
    const array& indices,
    int axis,
    StreamOrDevice s = {});

/** Take array entries at the given indices treating the array as flattened. */
array take(const array& a, const array& indices, StreamOrDevice s = {});

/** Take array entries given indices along the axis. */
array take_along_axis(
    const array& a,
    const array& indices,
    int axis,
    StreamOrDevice s = {});

/** Scatter updates to the given indices.
 *
 * The parameters ``indices`` and ``axes`` determine the locations of ``a``
 * that are updated with the values in ``updates``. Assuming 1-d ``indices``
 * for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which
 * the values in ``updates`` will be applied. Note each array in
 * ``indices`` is assigned to a corresponding axis and hence ``indices.size() ==
 * axes.size()``. If an index/axis pair is not provided then indices along that
 * axis are assumed to be zero.
 *
 * Note the rank of ``updates`` must be equal to the sum of the rank of the
 * broadcasted ``indices`` and the rank of ``a``. In other words, assuming the
 * arrays in ``indices`` have the same shape, ``updates.ndim() ==
 * indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates``
 * correspond to the indices, and the remaining ``a.ndim()`` dimensions are the
 * values that will be applied to the given location in ``a``.
 *
 * For example:
 *
 * @code
 * auto in = zeros({4, 4}, float32);
 * auto indices = array({2});
 * auto updates = reshape(arange(1, 3, float32), {1, 1, 2});
 * std::vector<int> axes{0};
 *
 * auto out = scatter(in, {indices}, updates, axes);
 * @endcode
 *
 * will produce:
 *
 * @code
 * array([[0, 0, 0, 0],
 *        [0, 0, 0, 0],
 *        [1, 2, 0, 0],
 *        [0, 0, 0, 0]], dtype=float32)
 * @endcode
 *
 * This scatters the two-element row vector ``[1, 2]`` starting at the ``(2,
 * 0)`` position of ``a``.
 *
 * Adding another element to ``indices`` will scatter into another location of
 * ``a``. We also have to add another update for the new index:
 *
 * @code
 * auto in = zeros({4, 4}, float32);
 * auto indices = array({2, 0});
 * auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
 * std::vector<int> axes{0};
 *
 * auto out = scatter(in, {indices}, updates, axes);
 * @endcode
 *
 * will produce:
 *
 * @code
 * array([[3, 4, 0, 0],
 *        [0, 0, 0, 0],
 *        [1, 2, 0, 0],
 *        [0, 0, 0, 0]], dtype=float32)
 * @endcode
 *
 * To control the scatter location on an additional axis, add another index
 * array to ``indices`` and another axis to ``axes``:
 *
 * @code
 * auto in = zeros({4, 4}, float32);
 * auto indices = std::vector{array({2, 0}), array({1, 2})};
 * auto updates = reshape(arange(1, 5, float32), {2, 1, 2});
 * std::vector<int> axes{0, 1};
 *
 * auto out = scatter(in, indices, updates, axes);
 * @endcode
 *
 * will produce:
 *
 * @code
 * array([[0, 0, 3, 4],
 *       [0, 0, 0, 0],
 *       [0, 1, 2, 0],
 *       [0, 0, 0, 0]], dtype=float32)
 * @endcode
 *
 * Items in indices are broadcasted together. This means:
 *
 * @code
 * auto indices = std::vector{array({2, 0}), array({1})};
 * @endcode
 *
 * is equivalent to:
 *
 * @code
 * auto indices = std::vector{array({2, 0}), array({1, 1})};
 * @endcode
 *
 * Note, ``scatter`` does not perform bounds checking on the indices and
 * updates.  Out-of-bounds accesses on ``a`` are undefined and typically result
 * in unintended or invalid memory writes.
 */
array scatter(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
/** Single index array / single axis form of `scatter`. */
inline array scatter(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  // Promote the lone index array and axis to one-element lists.
  const std::vector<array> index_list{indices};
  const std::vector<int> axes{axis};
  return scatter(a, index_list, updates, axes, s);
}

/** Scatter updates to the given indices, adding them in. */
array scatter_add(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
/** Single index array / single axis form of `scatter_add`. */
inline array scatter_add(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  const std::vector<array> index_list{indices};
  const std::vector<int> axes{axis};
  return scatter_add(a, index_list, updates, axes, s);
}

/**
 * Scatter and prod updates to given indices.
 *
 * Takes the same parameter list as ``scatter``; see its documentation for how
 * ``indices``, ``updates`` and ``axes`` relate.
 * NOTE(review): presumably the updates are multiplied into the existing
 * values of ``a`` -- confirm against the implementation.
 */
array scatter_prod(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
/** Single-index convenience overload of ``scatter_prod``: wraps ``indices``
 * and ``axis`` into one-element vectors and forwards to the general overload.
 */
inline array scatter_prod(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  const std::vector<array> index_list{indices};
  const std::vector<int> axes{axis};
  return scatter_prod(a, index_list, updates, axes, s);
}

/**
 * Scatter and max updates to given indices.
 *
 * Takes the same parameter list as ``scatter``.
 * NOTE(review): this comment previously said "linear indices", but the
 * signature takes one index array per axis exactly like the other scatter
 * variants -- confirm that the indices are not linear.
 */
array scatter_max(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
/** Single-index convenience overload of ``scatter_max``: wraps ``indices``
 * and ``axis`` into one-element vectors and forwards to the general overload.
 */
inline array scatter_max(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  const std::vector<array> index_list{indices};
  const std::vector<int> axes{axis};
  return scatter_max(a, index_list, updates, axes, s);
}
/**
 * Scatter and min updates to given indices.
 *
 * Takes the same parameter list as ``scatter``.
 * NOTE(review): this comment previously said "linear indices", but the
 * signature takes one index array per axis exactly like the other scatter
 * variants -- confirm that the indices are not linear.
 */
array scatter_min(
    const array& a,
    const std::vector<array>& indices,
    const array& updates,
    const std::vector<int>& axes,
    StreamOrDevice s = {});
/** Single-index convenience overload of ``scatter_min``: wraps ``indices``
 * and ``axis`` into one-element vectors and forwards to the general overload.
 */
inline array scatter_min(
    const array& a,
    const array& indices,
    const array& updates,
    int axis,
    StreamOrDevice s = {}) {
  const std::vector<array> index_list{indices};
  const std::vector<int> axes{axis};
  return scatter_min(a, index_list, updates, axes, s);
}

/** Square root of the elements of an array. */
array sqrt(const array& a, StreamOrDevice s = {});

/** Reciprocal of the square root of the elements of an array. */
array rsqrt(const array& a, StreamOrDevice s = {});

/**
 * Softmax of an array over the given ``axes``.
 *
 * ``precise`` defaults to false.
 * NOTE(review): presumably ``precise`` requests a higher-precision
 * computation -- confirm against the implementation.
 */
array softmax(
    const array& a,
    const std::vector<int>& axes,
    bool precise = false,
    StreamOrDevice s = {});

/**
 * Softmax of an array; overload without an axis argument.
 * NOTE(review): presumably normalizes over all axes of ``a`` -- confirm.
 */
array softmax(const array& a, bool precise = false, StreamOrDevice s = {});

/** Softmax of an array along a single ``axis``: wraps the axis into a
 * one-element vector and forwards to the multi-axis overload. */
inline array
softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) {
  const std::vector<int> axes{axis};
  return softmax(a, axes, precise, s);
}

/** Raise the elements of ``a`` to the power of ``b``, element-wise. */
array power(const array& a, const array& b, StreamOrDevice s = {});

/**
 * Cumulative sum of an array along ``axis``.
 *
 * ``reverse`` defaults to false and ``inclusive`` defaults to true.
 * NOTE(review): presumably ``reverse`` accumulates starting from the end of
 * the axis and ``inclusive`` controls whether each output includes its own
 * input element -- confirm against the implementation.
 */
array cumsum(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/**
 * Cumulative product of an array along ``axis``.
 *
 * ``reverse`` defaults to false and ``inclusive`` defaults to true.
 * NOTE(review): presumably ``reverse`` accumulates starting from the end of
 * the axis and ``inclusive`` controls whether each output includes its own
 * input element -- confirm against the implementation.
 */
array cumprod(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/**
 * Cumulative max of an array along ``axis``.
 *
 * ``reverse`` defaults to false and ``inclusive`` defaults to true.
 * NOTE(review): presumably ``reverse`` accumulates starting from the end of
 * the axis and ``inclusive`` controls whether each output includes its own
 * input element -- confirm against the implementation.
 */
array cummax(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/**
 * Cumulative min of an array along ``axis``.
 *
 * ``reverse`` defaults to false and ``inclusive`` defaults to true.
 * NOTE(review): presumably ``reverse`` accumulates starting from the end of
 * the axis and ``inclusive`` controls whether each output includes its own
 * input element -- confirm against the implementation.
 */
array cummin(
    const array& a,
    int axis,
    bool reverse = false,
    bool inclusive = true,
    StreamOrDevice s = {});

/**
 * General convolution with a filter.
 *
 * Padding is passed as two vectors (``padding_lo``, ``padding_hi``) and
 * dilation as two vectors (``kernel_dilation``, ``input_dilation``). All
 * vector arguments default to empty, ``groups`` defaults to 1 and ``flip``
 * defaults to false.
 * NOTE(review): how empty vectors are expanded to per-dimension values is
 * decided by the implementation and is not visible from this header.
 */
array conv_general(
    array input,
    array weight,
    std::vector<int> stride = {},
    std::vector<int> padding_lo = {},
    std::vector<int> padding_hi = {},
    std::vector<int> kernel_dilation = {},
    std::vector<int> input_dilation = {},
    int groups = 1,
    bool flip = false,
    StreamOrDevice s = {});

/**
 * General convolution with a filter, using the same padding on both sides.
 *
 * Convenience overload: the single ``padding`` vector is used as both
 * ``padding_lo`` and ``padding_hi`` of the general overload; every other
 * argument is forwarded unchanged.
 */
inline array conv_general(
    const array& input,
    const array& weight,
    std::vector<int> stride = {},
    std::vector<int> padding = {},
    std::vector<int> kernel_dilation = {},
    std::vector<int> input_dilation = {},
    int groups = 1,
    bool flip = false,
    StreamOrDevice s = {}) {
  // Duplicate `padding` before the call: the order in which function
  // arguments are initialized is unspecified, so passing both `padding` and
  // `std::move(padding)` in the same call could copy a moved-from vector.
  std::vector<int> padding_hi = padding;
  // The vector parameters are local by-value copies, so hand them over by
  // move instead of copying them a second time.
  return conv_general(
      /* array input = */ input,
      /* array weight = */ weight,
      /* std::vector<int> stride = */ std::move(stride),
      /* std::vector<int> padding_lo = */ std::move(padding),
      /* std::vector<int> padding_hi = */ std::move(padding_hi),
      /* std::vector<int> kernel_dilation = */ std::move(kernel_dilation),
      /* std::vector<int> input_dilation = */ std::move(input_dilation),
      /* int groups = */ groups,
      /* bool flip = */ flip,
      /* StreamOrDevice s = */ s);
}

/**
 * 1D convolution with a filter.
 *
 * Defaults: ``stride`` 1, ``padding`` 0, ``dilation`` 1, ``groups`` 1.
 */
array conv1d(
    const array& input,
    const array& weight,
    int stride = 1,
    int padding = 0,
    int dilation = 1,
    int groups = 1,
    StreamOrDevice s = {});

/**
 * 2D convolution with a filter.
 *
 * ``stride``, ``padding`` and ``dilation`` are given as pairs, presumably one
 * value per spatial dimension (TODO confirm the ordering).
 * Defaults: ``stride`` {1, 1}, ``padding`` {0, 0}, ``dilation`` {1, 1},
 * ``groups`` 1.
 */
array conv2d(
    const array& input,
    const array& weight,
    const std::pair<int, int>& stride = {1, 1},
    const std::pair<int, int>& padding = {0, 0},
    const std::pair<int, int>& dilation = {1, 1},
    int groups = 1,
    StreamOrDevice s = {});

/**
 * 3D convolution with a filter.
 *
 * ``stride``, ``padding`` and ``dilation`` are given as 3-tuples, presumably
 * one value per spatial dimension (TODO confirm the ordering).
 * Defaults: ``stride`` {1, 1, 1}, ``padding`` {0, 0, 0}, ``dilation``
 * {1, 1, 1}, ``groups`` 1.
 */
array conv3d(
    const array& input,
    const array& weight,
    const std::tuple<int, int, int>& stride = {1, 1, 1},
    const std::tuple<int, int, int>& padding = {0, 0, 0},
    const std::tuple<int, int, int>& dilation = {1, 1, 1},
    int groups = 1,
    StreamOrDevice s = {});

/**
 * Quantized matmul multiplies x with a quantized matrix w.
 *
 * ``scales`` and ``biases`` are passed alongside ``w``. Defaults:
 * ``transpose`` true, ``group_size`` 64, ``bits`` 4 (the same ``group_size``
 * and ``bits`` defaults as ``quantize`` and ``dequantize``).
 * NOTE(review): presumably ``transpose`` selects multiplication by the
 * transpose of ``w`` -- confirm against the implementation.
 */
array quantized_matmul(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/**
 * Quantize a matrix along its last axis.
 *
 * Defaults: ``group_size`` 64, ``bits`` 4. Returns a tuple of three arrays.
 * NOTE(review): presumably the tuple holds the quantized matrix, the scales
 * and the biases, in the order ``dequantize`` takes them -- confirm.
 */
std::tuple<array, array, array> quantize(
    const array& w,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/**
 * Dequantize a matrix produced by quantize().
 *
 * Defaults: ``group_size`` 64, ``bits`` 4 (the same defaults as quantize()).
 * NOTE(review): presumably both must match the values used when the matrix
 * was quantized -- confirm.
 */
array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/**
 * Compute matrix products with matrix-level gather.
 *
 * Takes the same ``x``, ``w``, ``scales``, ``biases``, ``transpose``,
 * ``group_size`` and ``bits`` arguments (with the same defaults) as
 * ``quantized_matmul``, plus optional ``lhs_indices`` / ``rhs_indices`` that
 * default to ``std::nullopt``.
 * NOTE(review): presumably the index arrays select which matrices of ``x``
 * and ``w`` are multiplied -- confirm against the implementation.
 */
array gather_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    std::optional<array> lhs_indices = std::nullopt,
    std::optional<array> rhs_indices = std::nullopt,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/**
 * Returns a contraction of a and b over multiple dimensions.
 *
 * ``axis`` defaults to 2.
 * NOTE(review): presumably ``axis`` is the number of dimensions contracted
 * (last dimensions of ``a`` with first dimensions of ``b``) -- confirm.
 */
array tensordot(
    const array& a,
    const array& b,
    const int axis = 2,
    StreamOrDevice s = {});

/**
 * Returns a contraction of a and b, with the axes given explicitly as
 * ``axes_a`` (for ``a``) and ``axes_b`` (for ``b``).
 * NOTE(review): presumably the two lists are paired entry by entry and must
 * have equal length -- confirm against the implementation.
 */
array tensordot(
    const array& a,
    const array& b,
    const std::vector<int>& axes_a,
    const std::vector<int>& axes_b,
    StreamOrDevice s = {});

/** Compute the outer product of two vectors ``a`` and ``b``. */
array outer(const array& a, const array& b, StreamOrDevice s = {});

/** Compute the inner product of two vectors ``a`` and ``b``. */
array inner(const array& a, const array& b, StreamOrDevice s = {});

/**
 * Compute D = beta * C + alpha * (A @ B).
 *
 * ``c``, ``a`` and ``b`` are the C, A and B of the formula; ``alpha`` and
 * ``beta`` both default to 1.
 */
array addmm(
    array c,
    array a,
    array b,
    const float& alpha = 1.f,
    const float& beta = 1.f,
    StreamOrDevice s = {});

/**
 * Compute matrix product with block masking.
 *
 * ``block_size`` has no default; the three masks are optional and default to
 * ``std::nullopt``.
 * NOTE(review): presumably ``mask_out`` applies to the result and
 * ``mask_lhs`` / ``mask_rhs`` apply to ``a`` / ``b`` -- confirm.
 */
array block_masked_mm(
    array a,
    array b,
    int block_size,
    std::optional<array> mask_out = std::nullopt,
    std::optional<array> mask_lhs = std::nullopt,
    std::optional<array> mask_rhs = std::nullopt,
    StreamOrDevice s = {});

/**
 * Compute matrix product with matrix-level gather.
 *
 * ``lhs_indices`` and ``rhs_indices`` are optional and default to
 * ``std::nullopt``.
 * NOTE(review): presumably they select which matrices of ``a`` and ``b`` are
 * multiplied -- confirm against the implementation.
 */
array gather_mm(
    array a,
    array b,
    std::optional<array> lhs_indices = std::nullopt,
    std::optional<array> rhs_indices = std::nullopt,
    StreamOrDevice s = {});

/**
 * Extract a diagonal or construct a diagonal array.
 *
 * Defaults: ``offset`` 0, ``axis1`` 0, ``axis2`` 1.
 * NOTE(review): the ``axis1`` / ``axis2`` parameters suggest this function
 * only extracts diagonals (construction being what ``diag`` offers) --
 * confirm and tighten the summary if so.
 */
array diagonal(
    const array& a,
    int offset = 0,
    int axis1 = 0,
    int axis2 = 1,
    StreamOrDevice s = {});

/**
 * Extract diagonal from a 2d array or create a diagonal matrix.
 *
 * ``k`` defaults to 0.
 * NOTE(review): presumably ``k`` selects which diagonal is used, with 0 being
 * the main diagonal -- confirm.
 */
array diag(const array& a, int k = 0, StreamOrDevice s = {});

/**
 * Return the sum along a specified diagonal in the given array.
 *
 * This overload takes the diagonal ``offset``, the two axes ``axis1`` and
 * ``axis2``, and an explicit ``dtype``.
 */
array trace(
    const array& a,
    int offset,
    int axis1,
    int axis2,
    Dtype dtype,
    StreamOrDevice s = {});
/** Same as above without the ``dtype`` argument. */
array trace(
    const array& a,
    int offset,
    int axis1,
    int axis2,
    StreamOrDevice s = {});
/**
 * Overload taking only the array.
 * NOTE(review): the offset and axes it uses are chosen by the implementation
 * and are not visible from this header.
 */
array trace(const array& a, StreamOrDevice s = {});

/**
 * Implements the identity function but allows injecting dependencies to other
 * arrays. This ensures that these other arrays will have been computed
 * when the outputs of this function are computed.
 *
 * Unlike most operations in this header, it takes no ``StreamOrDevice``
 * argument.
 * NOTE(review): presumably the result has one output per entry of ``inputs``
 * -- confirm.
 */
std::vector<array> depends(
    const std::vector<array>& inputs,
    const std::vector<array>& dependencies);

/**
 * Convert an array to an array with at least 1, 2 or 3 dimensions
 * (``atleast_1d``, ``atleast_2d`` and ``atleast_3d`` respectively).
 *
 * Each function also has an overload that takes a vector of arrays and
 * returns a vector of arrays.
 */
array atleast_1d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_1d(
    const std::vector<array>& a,
    StreamOrDevice s = {});
array atleast_2d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_2d(
    const std::vector<array>& a,
    StreamOrDevice s = {});
array atleast_3d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_3d(
    const std::vector<array>& a,
    StreamOrDevice s = {});

/**
 * Extract the number of elements along some axes as a scalar array. Used to
 * allow shape dependent shapeless compilation (pun intended).
 *
 * ``dtype`` defaults to ``int32``; ``axes`` and ``inverted`` have no
 * defaults.
 * NOTE(review): the meaning of ``inverted`` is not documented here --
 * presumably it requests the reciprocal of the element count; confirm
 * against the implementation.
 */
array number_of_elements(
    const array& a,
    std::vector<int> axes,
    bool inverted,
    Dtype dtype = int32,
    StreamOrDevice s = {});

/**
 * Conjugate of an array.
 * NOTE(review): presumably the element-wise complex conjugate -- confirm.
 */
array conjugate(const array& a, StreamOrDevice s = {});

/**
 * Bitwise and of ``a`` and ``b``.
 * The ``operator&`` overload takes no stream or device argument.
 */
array bitwise_and(const array& a, const array& b, StreamOrDevice s = {});
array operator&(const array& a, const array& b);

/**
 * Bitwise inclusive or of ``a`` and ``b``.
 * The ``operator|`` overload takes no stream or device argument.
 */
array bitwise_or(const array& a, const array& b, StreamOrDevice s = {});
array operator|(const array& a, const array& b);

/**
 * Bitwise exclusive or of ``a`` and ``b``.
 * The ``operator^`` overload takes no stream or device argument.
 */
array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {});
array operator^(const array& a, const array& b);

/**
 * Shift bits to the left.
 * NOTE(review): presumably the bits of ``a`` are shifted by the amounts in
 * ``b`` -- confirm. The ``operator<<`` overload takes no stream or device
 * argument.
 */
array left_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator<<(const array& a, const array& b);

/**
 * Shift bits to the right.
 * NOTE(review): presumably the bits of ``a`` are shifted by the amounts in
 * ``b`` -- confirm. The ``operator>>`` overload takes no stream or device
 * argument.
 */
array right_shift(const array& a, const array& b, StreamOrDevice s = {});
array operator>>(const array& a, const array& b);

/**
 * View of an array with the given data type.
 * NOTE(review): presumably this reinterprets the underlying data as ``dtype``
 * rather than converting values as ``astype`` does -- confirm against the
 * implementation.
 */
array view(const array& a, const Dtype& dtype, StreamOrDevice s = {});
/** @} */

} // namespace mlx::core
